13. 测试
测试训练后的模型
我将演示两种测试模型的好方法:使用测试数据和进行推理。第一种方法和在 CNN 课程中提到的方法相似。在 test_loader
中迭代测试数据,记录测试损失并根据模型预测正确的标签数计算准确率。
计算方法是查看输出的舍入值。输出是一个介于 0-1 之间的 S 型函数输出,所以舍入值将是一个整数,表示概率最大的标签;0 或 1。然后将预测标签与真实标签进行比较;如果匹配,则记录为标签正确的测试影评。
# Get test data loss and accuracy
test_losses = [] # track loss
num_correct = 0
# init hidden state
h = net.init_hidden(batch_size)
net.eval()
# iterate over test data
for inputs, labels in test_loader:
# Creating new variables for the hidden state, otherwise
# we'd backprop through the entire training history
h = tuple([each.data for each in h])
if(train_on_gpu):
inputs, labels = inputs.cuda(), labels.cuda()
# get predicted outputs
output, h = net(inputs, h)
# calculate loss
test_loss = criterion(output.squeeze(), labels.float())
test_losses.append(test_loss.item())
# convert output probabilities to predicted class (0 or 1)
pred = torch.round(output.squeeze()) # rounds to the nearest integer
# compare predictions to true label
correct_tensor = pred.eq(labels.float().view_as(pred))
correct = np.squeeze(correct_tensor.numpy()) if not train_on_gpu else np.squeeze(correct_tensor.cpu().numpy())
num_correct += np.sum(correct)
# -- stats! -- ##
# avg test loss
print("Test loss: {:.3f}".format(np.mean(test_losses)))
# accuracy over all test data
test_acc = num_correct/len(test_loader.dataset)
print("Test accuracy: {:.3f}".format(test_acc))
在下面输出平均测试损失和准确率,即分类正确的项目数除以测试数据总数。
测试损失是 0.516
,准确率约为 81.1%。

测试结果
接下来是最后一个任务了。也就是定义 predict
函数,用模型对任何给定文本影评进行推理。
练习:对测试影评进行推理
你可以将 test_review
更改为任何文本。读一读并判断:是正面还是负面影评?然后看看模型能否正确预测。
练习:请编写一个
predict
函数,参数包括训练过的网络、纯文本影评和序列长度,然后针对正面或负面影评输出自定义描述性句子。
- 你可以使用你已经定义过的任何函数,或定义帮助完成
predict
的辅助函数,但是参数只能包括训练过的网络、文本影评和序列长度。
def predict(net, test_review, sequence_length=200):
''' Prints out whether a give review is predicted to be
positive or negative in sentiment, using a trained model.
params:
net - A trained net
test_review - a review made of normal text and punctuation
sequence_length - the padded length of a review
'''
# print custom response based on whether test_review is pos/neg
请试着自己完成这道练习,然后看看 solution。